'''
Only the constant ARGS should be used from this module.
At import time, the module parses the command-line arguments.
Nothing else with side effects happens during import.
This is a part of main.initialize().
'''


import argparse
import shlex
import sys
import time


def make_arg_parser():
    """Build the argument parser for all experiment options.

    Options are organised into argument groups (base, wandb, train, network,
    feature, augmentation, consistency). Boolean options take an explicit
    value parsed by ``str2bool`` (e.g. ``--use_wandb true``).

    Returns:
        argparse.ArgumentParser: the configured, not yet invoked, parser.
    """

    def str2bool(v):
        """Convert a yes/no style string to a bool.

        Accepts (case-insensitive, surrounding whitespace ignored)
        yes/true/t/y/1 and no/false/f/n/0. A value that is already a bool
        is returned unchanged.

        Raises:
            argparse.ArgumentTypeError: for any other value, so argparse
                reports a usage error instead of a traceback.
        """
        if isinstance(v, bool):
            return v
        normalized = v.strip().lower()
        if normalized in ('yes', 'true', 't', 'y', '1'):
            return True
        if normalized in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')

    # ARGUMENTS START
    parser = argparse.ArgumentParser()

    # Add arguments below
    base_args = parser.add_argument_group('Base args')
    # Overwritten by get_args() with the reconstructed command line.
    base_args.add_argument('--run_script')
    base_args.add_argument('--dataset_name', type=str, default='ASSISTments2009',
                           choices=['ASSISTments2009', 'ASSISTments2015', 'ASSISTmentsChall', 'STATICS2011', 'EdNet-KT1'])
    base_args.add_argument('--use_wandb', type=str2bool, default=False)
    base_args.add_argument('--device', type=str, default='cpu')
    base_args.add_argument('--gpu', type=str)
    base_args.add_argument('--num_workers', type=int, default=4)
    # NOTE(review): relative path, unlike the absolute '/shared/...' defaults
    # below — confirm this is intended.
    base_args.add_argument('--data_root', type=str, default='shared/ASSISTmentsChall')
    base_args.add_argument('--item_info_root', type=str, default='/shared/EdNet-KT1-new/payment_diagnosis_questions.csv')
    base_args.add_argument('--sid_mapper_root', type=str, default='/shared/EdNet-KT1-new/sid_mapper.csv')
    base_args.add_argument('--split_num', type=int, default=1, choices=[1, 2, 3, 4, 5],
                           help="split number for cross-validation")
    base_args.add_argument('--train_small_rate', type=float, default=1.0)
    # base_args.add_argument('--vanilla_trained_model_weight_path', type=str)
    # base_args.add_argument('--augmentation_trained_model_weight_path', type=str)
    # base_args.add_argument('--compare_len', type=int, default=20)

    wandb_args = parser.add_argument_group('wandb args')
    wandb_args.add_argument('--project', type=str)
    wandb_args.add_argument('--name', type=str)
    # Comma-separated list; split into a list by get_args().
    wandb_args.add_argument('--wandb_tags', type=str)

    train_args = parser.add_argument_group('Train args')
    train_args.add_argument('--random_seed', type=int, default=2)
    train_args.add_argument('--num_epochs', type=int, default=100)
    train_args.add_argument('--num_steps', type=int, default=5000)
    train_args.add_argument('--train_batch', type=int, default=128)
    train_args.add_argument('--test_batch', type=int, default=128)
    train_args.add_argument('--eval_steps', type=int, default=100)
    train_args.add_argument('--lr', type=float, default=0.001)
    train_args.add_argument('--min_seq_size', type=int, default=0)
    train_args.add_argument('--seq_size', type=int, default=100)  # 3 should be enough for most testing.
    train_args.add_argument('--max_elapsed_time', type=float, default=300)
    train_args.add_argument('--max_lag_time', type=float, default=300)
    train_args.add_argument('--train_data_frac', type=float, default=None)
    train_args.add_argument('--val_data_frac', type=float, default=None)
    train_args.add_argument('--test_data_frac', type=float, default=None)
    train_args.add_argument('--augment_front', type=float, default=0.0)
    train_args.add_argument('--augment_back', type=float, default=1.0)

    network_args = parser.add_argument_group('Network args')
    network_args.add_argument('--model_type', type=str, default='DKT',
                              choices=['DKT', 'qDKT', 'DKVMN', 'SAKT', 'SAINT'])
    network_args.add_argument('--layer_count', type=int, default=2)
    network_args.add_argument('--head_count', type=int, default=8)
    network_args.add_argument('--embed_sum', type=str2bool, default=False)
    network_args.add_argument('--warm_up_step_count', type=int, default=4000)
    network_args.add_argument('--d_model_count', type=int, default=256)
    network_args.add_argument('--concept_num', type=int, default=64,
                              help="concept number in DKVMN")
    network_args.add_argument('--dropout_rate', type=float, default=0.0)
    network_args.add_argument('--collate_fn', type=str, default='max')
    network_args.add_argument('--lap_weight', type=float, default=0.0,
                              help="weight for laplacian regularizer, qDKT")

    feature_args = parser.add_argument_group('Feature args')
    # Known feature names: item_idx, is_correct, interaction_idx, position,
    # lag_time, elapsed_time, part, tags, type_ofs
    feature_args.add_argument('--enc_feature_names', nargs='+', default=['item_idx', 'tags', 'position'])
    feature_args.add_argument('--dec_feature_names', nargs='+', default=['is_correct', 'position'])

    # type=int so values given on the command line are ints, like the defaults
    # (argparse would otherwise return them as strings).
    feature_args.add_argument('--enc_feature_dims', nargs='+', type=int, default=[256, 256, 256])
    feature_args.add_argument('--dec_feature_dims', nargs='+', type=int, default=[256, 256])

    augmentation_args = parser.add_argument_group('Augmentation args')
    augmentation_args.add_argument('--augmentations', nargs='+', default=[],
                                   help="list of augmentations to be applied: "
                                        "rep (replacement), ins (insertion), del (deletion)"
                                   )
    augmentation_args.add_argument('--del_prob', type=float, default=0.1)
    augmentation_args.add_argument('--ins_prob', type=float, default=0.1)
    augmentation_args.add_argument('--rep_prob', type=float, default=0.1)

    consistency_args = parser.add_argument_group('Consistency args')
    # Defaults are real bools: argparse does not run `type` on non-string
    # defaults, so an int default would leak through unconverted.
    consistency_args.add_argument('--rep_pred', type=str2bool, default=False)
    consistency_args.add_argument('--rep_only', type=str2bool, default=False)
    consistency_args.add_argument('--rep_type', type=str, default='skill',
                                  help="replacement type, question-random (q-rand), "
                                       "interaction-random (i-rand), and skill-based (skill)",
                                  choices=['q-rand', 'i-rand', 'skill', 'dif-skill'])
    consistency_args.add_argument('--rep_weight', type=float, default=100.0)
    consistency_args.add_argument('--ins_weight', type=float, default=1.0)
    consistency_args.add_argument('--del_weight', type=float, default=1.0)
    consistency_args.add_argument('--ins_type', type=str, default='random',
                                  choices=['random', 'skill'],
                                  help="type of interactions to be inserted.")
    consistency_args.add_argument('--del_type', type=str, default='random',
                                  choices=['random', 'skill'],
                                  help="type of interactions to be deleted.")
    consistency_args.add_argument('--rep_response', type=str, default='all',
                                  choices=['all', '1', '0'])
    consistency_args.add_argument('--ins_response', type=str, default='rand',
                                  choices=['rand', '1', '0'],
                                  help="response for insertion. "
                                       "rand is 50/50 random sampling, and "
                                       "1, 0 are fixed response (correct/incorrect)")
    consistency_args.add_argument('--ins_loss_dir', type=str, default='up',
                                  choices=['up', 'down'],
                                  help="direction of constraints for insertion")
    consistency_args.add_argument('--del_response', type=str, default=None,
                                  choices=[None, '1', '0'],
                                  help="For None, delete any interactions randomly, and "
                                       "for 1 and 0, delete interactions with given responses (correct/incorrect)")
    consistency_args.add_argument('--del_loss_dir', type=str, default='down',
                                  choices=['up', 'down'],
                                  help="direction of constraints for deletion")
    # String defaults are passed through str2bool by argparse ('0' -> False).
    consistency_args.add_argument('--rep_kt_loss', type=str2bool, default='0')
    consistency_args.add_argument('--ins_kt_loss', type=str2bool, default='0')
    consistency_args.add_argument('--del_kt_loss', type=str2bool, default='0')
    consistency_args.add_argument('--rep_cons_loss', type=str2bool, default='1')
    consistency_args.add_argument('--ins_cons_loss', type=str2bool, default='1')
    consistency_args.add_argument('--del_cons_loss', type=str2bool, default='1')
    # ARGUMENTS END

    return parser


def get_args(argv=None):
    """Parse command-line arguments and decorate them with derived fields.

    Args:
        argv: Arguments to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace containing every option from make_arg_parser(),
        with these adjustments:
          * ``name`` defaults to a generated run name when not given;
          * ``run_script`` holds the shell-quoted command line;
          * ``wandb_tags`` becomes a list (``['test']`` when none given);
          * ``gpu`` becomes a list of ints and ``device`` is "cuda" when
            ``--gpu`` is given, otherwise ``device`` is "cpu";
          * ``log_path`` / ``weight_path`` are derived from ``name``.

    Exits via ``parser.error`` (SystemExit) on a malformed ``--gpu`` value.
    """
    parser = make_arg_parser()
    args = parser.parse_args(argv)

    # set default name if there is none
    if args.name is None:
        args.name = (f'tfm_l{args.layer_count}_dim{args.d_model_count}_'
                     f'lr{args.lr}_h{args.head_count}_sum{args.embed_sum}'
                     f'_{int(time.time())}')

    # run_script: record the command line with each token shell-quoted so
    # arguments containing spaces or shell metacharacters can be re-run.
    cli_args = sys.argv[1:] if argv is None else list(argv)
    tokens = sys.argv[:1] + cli_args
    args.run_script = ' '.join(['python'] + [shlex.quote(t) for t in tokens])

    # parse tags: comma-separated, surrounding whitespace and empty entries
    # are dropped
    tags = []
    if args.wandb_tags is not None:
        tags = [t.strip() for t in args.wandb_tags.split(',') if t.strip()]
    args.wandb_tags = tags or ['test']

    # parse gpus
    if args.gpu is not None:
        raw_gpu = args.gpu
        try:
            gpu_ids = [int(g) for g in raw_gpu.split(',') if g.strip()]
        except ValueError:
            gpu_ids = []
        if not gpu_ids:
            # parser.error prints usage and exits instead of a raw traceback.
            parser.error(f'--gpu expects comma-separated integer ids, got {raw_gpu!r}')
        args.device = "cuda"
        args.gpu = gpu_ids  # doesn't support multi-gpu yet
    else:
        # NOTE: this overrides any explicit --device value; device is derived
        # solely from --gpu.
        args.device = "cpu"

    # locate directory to save log/weight
    args.log_path = f'log/{args.name}.log'
    args.weight_path = f'weight/{args.name}/'

    return args


# Parsed once, at import time; other modules should read options from ARGS
# rather than calling get_args() themselves.
ARGS = get_args()
